Variational Autoencoders

A Variational Autoencoder (VAE) is a generative model that uses neural networks to encode input data into a latent space and to decode latent representations back into reconstructions of the original data. VAEs combine principles from deep learning and probabilistic graphical models, enabling unsupervised learning of complex data distributions.

Architecture

The VAE consists of three main components:

Encoder

  • Transforms input data x into a latent representation z.
  • Outputs the parameters of the approximate posterior distribution q_{\phi}(z|x), typically the mean \mu and the log-variance \log \sigma^2.
  • Implemented as a neural network parameterized by \phi.

Latent Space

  • A lower-dimensional space representing the encoded features of the input data.
  • Imposes a prior distribution p(z), usually a standard normal distribution \mathcal{N}(0, I).
  • Enables sampling and generation of new data instances.

Decoder

  • Reconstructs the input data from the latent representation z.
  • Defines the likelihood p_{\theta}(x|z) of the data given the latent variables.
  • Implemented as a neural network parameterized by \theta.

Mathematical Formulation

The VAE optimizes the Evidence Lower Bound (ELBO) on the marginal likelihood:

\mathcal{L}(\phi, \theta; x) = \mathrm{E}_{q_{\phi}(z|x)} \left[ \log p_{\theta}(x|z) \right] - \mathrm{KL}\left( q_{\phi}(z|x) \parallel p(z) \right)

Where:

  • q_{\phi}(z|x): Approximate posterior distribution.
  • p_{\theta}(x|z): Likelihood of the data given the latent variables.
  • \mathrm{KL}(\cdot \parallel \cdot): Kullback-Leibler divergence between two distributions.

Loss Function

The loss function combines two terms:

  1. Reconstruction Loss (\mathcal{L}_{\text{rec}}):

    Measures how well the decoder reconstructs the input data.

  2. Regularization Term (\mathcal{L}_{\text{reg}}):

    Encourages the latent distribution q_{\phi}(z|x) to stay close to the prior p(z). A minimal code sketch combining both terms is given after this list.
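
As a minimal PyTorch sketch (the function and tensor names are illustrative, not from a specific library), the two terms can be computed from the decoder output x_hat and the encoder outputs mu and logvar; the full MNIST example later in this article follows the same pattern:

import torch
import torch.nn.functional as F

def vae_loss(x_hat, x, mu, logvar):
    # Reconstruction term: binary cross-entropy, assuming inputs scaled to [0, 1]
    rec = F.binary_cross_entropy(x_hat, x, reduction='sum')
    # Regularization term: closed-form KL divergence between N(mu, sigma^2) and N(0, I)
    reg = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return rec + reg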

Reparameterization Trick

To enable backpropagation through stochastic variables, the reparameterization trick is used:

z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)

  • Allows gradients to flow through \mu and \sigma during training (see the sketch below).
  • \odot denotes element-wise multiplication.
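
The snippet below is a small sketch (the tensor sizes are arbitrary assumptions) showing that a sample drawn this way remains differentiable with respect to \mu and \sigma, which is exactly what the trick buys us:

import torch

mu = torch.zeros(3, requires_grad=True)
log_sigma = torch.zeros(3, requires_grad=True)

eps = torch.randn(3)                     # noise from N(0, I); no gradient needed
z = mu + torch.exp(log_sigma) * eps      # reparameterized sample

z.sum().backward()                       # gradients flow back through the sample
print(mu.grad, log_sigma.grad)           # both gradients are populated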

Training Process

  1. Encoding:

    • Input data x is passed through the encoder.
    • Outputs mean \mu and log-variance \log \sigma^2.
  2. Sampling:

    • Sample z from the latent space using the reparameterization trick.
  3. Decoding:

    • The sampled z is passed through the decoder to reconstruct \hat{x}.
  4. Loss Computation:

    • Calculate the reconstruction loss and the regularization term.
    • Combine them to form the total loss.
  5. Optimization:

    • Update the network parameters \phi and \theta using gradient descent.

Key Concepts

Variational Inference

  • A technique to approximate complex probability distributions.
  • Transforms inference into an optimization problem.

Kullback-Leibler Divergence

  • Measures the difference between two probability distributions.
  • Encourages the learned distribution to be similar to the prior.

Applications

  • Data Generation: Generate new data samples similar to the training data.
  • Anomaly Detection: Identify outliers by measuring reconstruction error (a short sketch follows this list).
  • Dimensionality Reduction: Compress data into a lower-dimensional latent space.
  • Image and Text Modeling: Generate realistic images or text sequences.
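
As a hedged sketch of the anomaly-detection use case, assuming a VAE with the interface of the implementation examples later in this article and already trained on normal data, the per-sample reconstruction error can be used as an anomaly score:

import torch
import torch.nn.functional as F

def anomaly_scores(model, x):
    # Higher reconstruction error suggests x is unlike the training data
    model.eval()
    with torch.no_grad():
        x_hat, mu, logvar = model(x)
        # Per-sample binary cross-entropy, summed over input dimensions
        return F.binary_cross_entropy(x_hat, x, reduction='none').sum(dim=1)

# scores = anomaly_scores(trained_vae, batch)  # flag samples with unusually high scores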

Extensions and Variants

Conditional VAE (CVAE)

  • Incorporates additional information y (for example, a class label) into both the encoder and decoder.
  • Models the conditional distribution p(x|z, y); a minimal sketch follows below.
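
The sketch below shows one minimal way (not the only one) to make the earlier VAE conditional, assuming y is a one-hot label vector that is simply concatenated to the inputs of the encoder and decoder:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CVAE(nn.Module):
    def __init__(self, input_dim, label_dim, latent_dim, h_dim=400):
        super().__init__()
        # Encoder sees the input concatenated with the label
        self.fc1 = nn.Linear(input_dim + label_dim, h_dim)
        self.fc_mu = nn.Linear(h_dim, latent_dim)
        self.fc_logvar = nn.Linear(h_dim, latent_dim)
        # Decoder sees the latent code concatenated with the label
        self.fc2 = nn.Linear(latent_dim + label_dim, h_dim)
        self.fc3 = nn.Linear(h_dim, input_dim)

    def forward(self, x, y):
        h = F.relu(self.fc1(torch.cat([x, y], dim=1)))
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)  # reparameterization
        x_hat = torch.sigmoid(self.fc3(F.relu(self.fc2(torch.cat([z, y], dim=1)))))
        return x_hat, mu, logvar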

β-VAE

  • Introduces a hyperparameter \beta to balance the reconstruction and regularization terms:

    \mathcal{L} = \mathrm{E}_{q_{\phi}(z|x)} \left[ \log p_{\theta}(x|z) \right] - \beta \, \mathrm{KL}\left( q_{\phi}(z|x) \parallel p(z) \right)
  • Encourages disentangled latent representations.
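
In code, this is a single weighting factor on the KL term; below is a hedged sketch reusing the vae_loss pattern from earlier (beta > 1 is the typical setting when disentanglement is the goal):

import torch
import torch.nn.functional as F

def beta_vae_loss(x_hat, x, mu, logvar, beta=4.0):
    # beta = 1.0 recovers the standard VAE objective; beta = 4.0 is just an illustrative choice
    rec = F.binary_cross_entropy(x_hat, x, reduction='sum')
    reg = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return rec + beta * reg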

Mathematical Expressions

KL Divergence for Gaussian Distributions:

\mathrm{KL}\left( q_{\phi}(z|x) \parallel p(z) \right) = -\frac{1}{2} \sum_{i=1}^{d} \left(1 + \log \sigma_i^2 - \mu_i^2 - \sigma_i^2\right)

Where:

  • d: Dimensionality of the latent space.
  • \mu_i and \sigma_i^2: Mean and variance of the i-th latent dimension.
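
This closed form is exactly what the loss functions in the code below compute. As a small sanity-check sketch (the tensor shapes are arbitrary assumptions), it can be compared against torch.distributions:

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(5, 20)        # batch of 5 samples, latent dimensionality d = 20
logvar = torch.randn(5, 20)
std = torch.exp(0.5 * logvar)

# Closed-form expression, summed over the d latent dimensions
kl_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

# The same quantity via torch.distributions
kl_lib = kl_divergence(Normal(mu, std), Normal(0.0, 1.0)).sum(dim=1)

print(torch.allclose(kl_closed, kl_lib, atol=1e-5))  # expected: True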

Implementation Example

Below is a simplified example of a VAE implemented using Python and PyTorch:

import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(VAE, self).__init__()
        # Encoder layers
        self.fc1 = nn.Linear(input_dim, 400)
        self.fc_mu = nn.Linear(400, latent_dim)
        self.fc_logvar = nn.Linear(400, latent_dim)
        # Decoder layers
        self.fc2 = nn.Linear(latent_dim, 400)
        self.fc3 = nn.Linear(400, input_dim)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc_mu(h1), self.fc_logvar(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # standard deviation from the log-variance
        eps = torch.randn_like(std)    # noise drawn from N(0, I)
        return mu + std * eps

    def decode(self, z):
        h2 = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h2))

    def forward(self, x):
        # Flatten each sample to a vector of length input_dim
        mu, logvar = self.encode(x.view(x.size(0), -1))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
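
As a quick usage sketch (the batch size and dimensions are illustrative assumptions), the class above can be exercised on a dummy batch, reusing the imports from the example:

model = VAE(input_dim=784, latent_dim=20)   # illustrative sizes
x = torch.rand(32, 784)                     # dummy batch with values in [0, 1]

x_hat, mu, logvar = model(x)
print(x_hat.shape, mu.shape, logvar.shape)  # (32, 784), (32, 20), (32, 20)

# ELBO loss: reconstruction term plus KL divergence, as in the formulas above
loss = F.binary_cross_entropy(x_hat, x, reduction='sum') \
       - 0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss.backward()                             # gradients for encoder and decoder parameters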

Advantages and Limitations

Advantages

  • Generative Capabilities: Can generate new data samples.
  • Unsupervised Learning: Learns without labeled data.
  • Continuous Latent Space: Enables smooth interpolation between data points.

Limitations

  • Blurriness in Outputs: Generated samples may lack sharpness.
  • Training Complexity: Requires careful tuning of hyperparameters.
  • Posterior Collapse: A powerful decoder may learn to ignore the latent code, and generated samples can be less diverse or detailed than those from models like GANs.

Variational Autoencoder (VAE) on MNIST Using PyTorch

Below is a full example of implementing a Variational Autoencoder (VAE) on the MNIST dataset using PyTorch. The code includes data loading, model definition, training loop, and visualization of the reconstructed images.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_dim = 20
batch_size = 128
learning_rate = 1e-3
num_epochs = 10

# MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = datasets.MNIST(root='data',
                               train=True,
                               transform=transform,
                               download=True)

test_dataset = datasets.MNIST(root='data',
                              train=False,
                              transform=transform)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False)

# VAE Model
class VAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        # Encoder layers
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc_mu = nn.Linear(h_dim, z_dim)      # Mean of the latent space
        self.fc_logvar = nn.Linear(h_dim, z_dim)  # Log variance of the latent space
        # Decoder layers
        self.fc2 = nn.Linear(z_dim, h_dim)
        self.fc3 = nn.Linear(h_dim, image_size)

    def encode(self, x):
        h = torch.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # Standard deviation
        eps = torch.randn_like(std)    # Random noise from N(0, I)
        return mu + eps * std

    def decode(self, z):
        h = torch.relu(self.fc2(z))
        x_reconst = torch.sigmoid(self.fc3(h))
        return x_reconst

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_reconst = self.decode(z)
        return x_reconst, mu, logvar

model = VAE(z_dim=latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Loss function
def loss_function(x_reconst, x, mu, logvar):
    # Reconstruction loss (binary cross-entropy)
    BCE = nn.functional.binary_cross_entropy(x_reconst, x, reduction='sum')
    # KL divergence
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Total loss
    return BCE + KLD

# Training loop
model.train()
for epoch in range(num_epochs):
    train_loss = 0
    for batch_idx, (images, _) in enumerate(train_loader):
        images = images.view(-1, 784).to(device)
        optimizer.zero_grad()
        x_reconst, mu, logvar = model(images)
        loss = loss_function(x_reconst, images, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    avg_loss = train_loss / len(train_loader.dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

# Testing and visualization
model.eval()
with torch.no_grad():
    # Get a batch of test images
    test_images, _ = next(iter(test_loader))
    test_images = test_images.view(-1, 784).to(device)
    # Reconstruct images
    x_reconst, _, _ = model(test_images)
    x_reconst = x_reconst.view(-1, 1, 28, 28).cpu()
    # Original images
    original_images = test_images.view(-1, 1, 28, 28).cpu()

# Visualize the reconstructed images
n = 8  # Number of images to display
plt.figure(figsize=(15, 4))
for i in range(n):
    # Original images
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(original_images[i][0], cmap='gray')
    ax.axis('off')
    # Reconstructed images
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(x_reconst[i][0], cmap='gray')
    ax.axis('off')
plt.show()

Explanation

  • Imports: The necessary libraries are imported, including torch, torchvision, and matplotlib.
  • Device Configuration: The code checks if a GPU is available and sets the device accordingly.
  • Hyperparameters: Key hyperparameters such as latent_dim, batch_size, learning_rate, and num_epochs are defined.
  • Data Loading: The MNIST dataset is loaded with appropriate transformations, and data loaders are created for training and testing.
  • Model Definition: A VAE class is defined, inheriting from nn.Module. It includes methods for encoding, reparameterization, decoding, and the forward pass.
    • Encoder: Maps input images to the latent space parameters (mu and logvar).
    • Reparameterization Trick: Samples z from the latent space using mu and logvar.
    • Decoder: Reconstructs the input image from the latent variable z.
  • Loss Function: Combines the reconstruction loss (binary cross-entropy) and the KL divergence to form the total loss.
  • Training Loop: The model is trained over the specified number of epochs. In each iteration:
    • The input images are flattened and moved to the device.
    • The model performs a forward pass to obtain the reconstructed images and latent variables.
    • The loss is computed and backpropagated.
    • The optimizer updates the model parameters.
  • Testing and Visualization:
    • The model switches to evaluation mode.
    • A batch of test images is passed through the model to obtain reconstructions.
    • Both original and reconstructed images are plotted using matplotlib for visual comparison.

Notes

  • Reparameterization Trick: Essential for allowing gradients to flow through stochastic nodes by expressing the sampling operation in terms of deterministic operations and a noise variable.
  • KL Divergence: Encourages the learned latent distribution to be close to the prior distribution (a standard normal distribution in this case), which is also what makes sampling from the prior meaningful (see the sketch below).
  • Reconstruction Loss: Measures how well the decoder reconstructs the input data; binary cross-entropy is suitable for binary images like MNIST.
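
Because the latent distribution is pulled toward the standard normal prior, new digits can be generated by decoding random latent vectors; the sketch below reuses model, latent_dim, and device from the example above:

# Generate new samples by decoding latent vectors drawn from the prior N(0, I)
model.eval()
with torch.no_grad():
    z = torch.randn(8, latent_dim).to(device)            # 8 random latent codes
    samples = model.decode(z).view(-1, 1, 28, 28).cpu()

plt.figure(figsize=(15, 2))
for i in range(samples.size(0)):
    ax = plt.subplot(1, samples.size(0), i + 1)
    plt.imshow(samples[i][0], cmap='gray')
    ax.axis('off')
plt.show()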

Potential Extensions

  • Hyperparameter Tuning: Experiment with different latent dimensions, learning rates, and network architectures to improve performance.
  • Conditional VAE: Modify the model to condition on labels, allowing for class-conditional image generation.
  • Different Datasets: Apply the VAE to more complex datasets like CIFAR-10 or CelebA, with appropriate architectural changes.

Conclusion

Variational Autoencoders provide a robust framework for generative modeling and unsupervised learning. By combining neural networks with probabilistic inference, VAEs can capture complex data distributions and have become a fundamental tool in the field of deep learning.